{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
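    "## This notebook fits alternative machine-learning models (RF, LGBM, PLS, LASSO, ENet) to forecast EPS and evaluates them, together with OLS and a composite, in Table F.3"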
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/uxlfoundation/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "# Enable the Intel extension for scikit-learn (sklearnex) for supported estimators.\n",
    "# NOTE(review): per the sklearnex docs, patching must happen before sklearn estimators are\n",
    "# imported (done in the next cell) - keep this cell first.\n",
    "from sklearnex import patch_sklearn\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from pandas.tseries.offsets import *\n",
    "from tqdm import tqdm\n",
    "from functools import reduce\n",
    "import statsmodels.api as sm\n",
    "import scipy.stats as stats\n",
    "from linearmodels import PanelOLS\n",
    "\n",
    "from functions import utils\n",
    "from functions import summary2\n",
    "\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.linear_model import LinearRegression, Lasso, ElasticNet\n",
    "from sklearn.cross_decomposition import PLSRegression\n",
    "from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "from lightgbm import LGBMRegressor\n",
    "import lightgbm as lgb\n",
    "from sklearn.model_selection import ParameterGrid\n",
    "\n",
    "import warnings\n",
    "# Hide DeprecationWarning messages\n",
    "warnings.filterwarnings('ignore', category=DeprecationWarning)\n",
    "\n",
    "# Figure style. 'Times New Roman' is registered as the sans-serif font, so it is used\n",
    "# whenever font.family is 'sans-serif' (matplotlib's default family).\n",
    "plt.rcParams.update({\n",
    "    'font.sans-serif': ['Times New Roman'],\n",
    "    'font.size': 13,\n",
    "    'xtick.direction': 'in',\n",
    "    'ytick.direction': 'in',\n",
    "    'grid.color': 'gray',\n",
    "    'grid.linestyle': '--',\n",
    "})\n",
    "%config InlineBackend.figure_format = 'retina'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ML Forecasts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Firm-level ratio characteristics: part of both feature sets (X_col_qtr / X_col_ann) and the\n",
    "# columns that are rescaled before the PLS / LASSO / ENet runs.\n",
    "ratio_chars = ['CAPEI', 'bm',\n",
    "       'evm', 'pe_exi', 'pe_inc', 'ps', 'pcf',\n",
    "       'dpr', 'npm', 'opmbd', 'opmad', 'gpm', 'ptpm', 'cfm', 'roa', 'roe',\n",
    "       'roce', 'efftax', 'aftret_eq', 'aftret_invcapx', 'aftret_equity',\n",
    "       'pretret_noa', 'pretret_earnat', 'GProf', 'equity_invcap',\n",
    "       'debt_invcap', 'totdebt_invcap', 'capital_ratio', 'int_debt',\n",
    "       'int_totdebt', 'cash_lt', 'invt_act', 'rect_act', 'debt_at',\n",
    "       'debt_ebitda', 'short_debt', 'curr_debt', 'lt_debt', 'profit_lct',\n",
    "       'ocf_lct', 'cash_debt', 'fcf_ocf', 'lt_ppent', 'dltt_be', 'debt_assets',\n",
    "       'debt_capital', 'de_ratio', 'intcov', 'intcov_ratio', 'cash_ratio',\n",
    "       'quick_ratio', 'curr_ratio', 'cash_conversion', 'inv_turn', 'at_turn',\n",
    "       'rect_turn', 'pay_turn', 'sale_invcap', 'sale_equity', 'sale_nwc',\n",
    "       'rd_sale', 'adv_sale', 'staff_sale', 'accrual', 'ptb', 'PEG_trailing',\n",
    "       'divyield']\n",
    "\n",
    "# Per-share characteristics; not referenced by the forecasting cells below.\n",
    "per_share_chars = ['dividend_p','BE_p','Liability_p','cur_liability_p','LT_debt_p',\n",
    "                  'cash_p', 'total_asset_p', 'tot_debt_p', 'accrual_p', 'EBIT_p', \n",
    "                   'cur_asset_p', 'pbda_p', 'ocf_p', 'inventory_p', 'receivables_p',\n",
    "                   'Cur_debt_p', 'interest_p', 'fcf_ocf_p', 'evm_p',\n",
    "                   'sales_p', 'invcap_p', 'c_equity_p', 'rd_p', 'opmad_p', 'gpm_p','ptpm_p'\n",
    "                  ]\n",
    "\n",
    "# Macro variables; the same four names are appended to X_col_qtr / X_col_ann.\n",
    "macro_chars = ['RGDP', 'RCON', 'INDPROD', 'UNEMP']\n",
    "\n",
    "# Return, price and lagged realized EPS; only ret, prc and the q1 / y1 lags enter the feature sets.\n",
    "fundamental_chars = ['ret', 'prc',\n",
    "                    'EPS_true_l1_q1','EPS_true_l1_q2','EPS_true_l1_q3',\n",
    "                    'EPS_true_l1_y1','EPS_true_l1_y2',\n",
    "                    ]\n",
    "\n",
    "# Analyst EPS forecast per horizon; the one matching the target is added to the features in each loop.\n",
    "analyst_chars = ['EPS_ana_q1','EPS_ana_q2','EPS_ana_q3','EPS_ana_y1','EPS_ana_y2']\n",
    "\n",
    "# Realized EPS to forecast, one per horizon (q1-q3 use the quarterly features, y1-y2 the annual ones).\n",
    "targets = ['EPS_true_q1', 'EPS_true_q2', 'EPS_true_q3', 'EPS_true_y1', 'EPS_true_y2']\n",
    "\n",
    "# Training panel; YearMonth must be datetime-like for the .dt accessor.\n",
    "df_tmp = pd.read_parquet('../data/Results/df_train_new.parquet')\n",
    "df_tmp['Year'] = df_tmp['YearMonth'].dt.year"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(train_data,validation_data,test_data,X_col,Y_col):\n",
    "    \n",
    "    train_X = train_data.dropna(subset=X_col+[Y_col])[X_col]\n",
    "    train_y = train_data.dropna(subset=X_col+[Y_col])[Y_col]\n",
    "    \n",
    "    valid_X = validation_data.dropna(subset=X_col+[Y_col])[X_col]\n",
    "    valid_y = validation_data.dropna(subset=X_col+[Y_col])[Y_col]\n",
    "    \n",
    "    train_valid_X = pd.concat([train_X, valid_X],axis=0)\n",
    "    train_valid_y = pd.concat([train_y, valid_y],axis=0)\n",
    "    \n",
    "    test_X = test_data.dropna(subset=X_col+[Y_col])[X_col]\n",
    "    test_y = test_data.dropna(subset=X_col+[Y_col])[Y_col]\n",
    "    \n",
    "    return train_X, train_y, valid_X, valid_y, train_valid_X, train_valid_y, test_X, test_y\n",
    "\n",
    "def GridSearch(mdl_class, param_grid, \n",
    "               train_X, train_y, valid_X, valid_y, \n",
    "               metrics, higher_better=True):\n",
    "    '''\n",
    "    GridSearch using given validation data for sklearn-type models.\n",
    "    mdl_class: e.g., RandomForestRegressor\n",
    "    param_grid: e.g.,\n",
    "        param_grid = {\n",
    "                    'n_estimators': [200,500],\n",
    "                    'max_depth' : [2,],#3,4,5,6],\n",
    "                    'max_features' : [3, 5,],# 10, 15]\n",
    "                    'random_state': [0]\n",
    "                    }\n",
    "    train_X, train_y, valid_X, valid_y: Train and Validation data\n",
    "    metrics: evaluation metrics\n",
    "    higher_better: if True, return the model with highest evaluation score\n",
    "    \n",
    "    Output: best_param(dict); best_mdl(the trained model)\n",
    "    '''\n",
    "    ## Function for Validation\n",
    "    validation_scores = []\n",
    "    # for each parameter, train a model and test on validation set\n",
    "    for params in ParameterGrid(param_grid):\n",
    "        mdl = mdl_class().set_params(**params)\n",
    "        mdl.fit(train_X, train_y)\n",
    "        validation_scores.append((params,mdl,metrics(valid_y, mdl.predict(valid_X))))\n",
    "        # break\n",
    "    # sort based on validation score    \n",
    "    validation_scores = sorted(validation_scores, key=lambda x: x[2])\n",
    "\n",
    "    if higher_better:\n",
    "        best_param = validation_scores[-1][0]\n",
    "        best_mdl = validation_scores[-1][1]\n",
    "    else:\n",
    "        best_param = validation_scores[0][0]\n",
    "        best_mdl = validation_scores[0][1]\n",
    "    return best_param, best_mdl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature sets: ratio characteristics + return, price, last realized EPS (quarterly or annual)\n",
    "# + macro variables. The horizon-specific analyst forecast is appended inside the loops.\n",
    "X_col_qtr = ratio_chars + ['ret', 'prc', 'EPS_true_l1_q1'] + macro_chars\n",
    "X_col_ann = ratio_chars + ['ret', 'prc', 'EPS_true_l1_y1'] + macro_chars\n",
    "\n",
    "# Rolling-window lengths in months: training window, then the validation window just before t\n",
    "train_window = 36\n",
    "validation_window = 12\n",
    "\n",
    "# Directory for the per-model forecast files ({model}_pred.parquet)\n",
    "output_dir = '../data/Results/ML_variants/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tree-based models, fitted on the unstandardized data loaded above. Each entry is\n",
    "# (abbreviation used in output columns/files, estimator class, hyper-parameter grid).\n",
    "# NOTE: Table F.3 reads OLS_pred.parquet; add ('OLS', LinearRegression, {}) to this list to\n",
    "# regenerate it. PLS, LASSO and ENet are run further below on standardized data.\n",
    "rf_grid = {'n_estimators': [1000], 'max_depth': [4, 6, 8, 10],\n",
    "           'max_samples': [0.05], 'min_samples_leaf': [5],\n",
    "           'random_state': [0], 'n_jobs': [14]}\n",
    "lgbm_grid = {'n_estimators': [100, 200, 300], 'learning_rate': [0.01, 0.03, 0.07, 0.1],\n",
    "             'max_depth': [3, 4, 5, 6], 'random_state': [0], 'verbose': [-1]}\n",
    "models = [\n",
    "    ('RF', RandomForestRegressor, rf_grid),\n",
    "    ('LGBM', LGBMRegressor, lgbm_grid),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/372 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 372/372 [3:07:40<00:00, 30.27s/it]  \n",
      "100%|██████████| 372/372 [13:33:11<00:00, 131.16s/it]  \n"
     ]
    }
   ],
   "source": [
    "## Rolling-window out-of-sample forecasts for the models in `models` (RF, LGBM): for every\n",
    "## month t and horizon, tune hyper-parameters on the validation window, refit on\n",
    "## train + validation with the best parameters, and predict the firms observed at t.\n",
    "time_idx = sorted(df_tmp['YearMonth'].unique())\n",
    "time_idx = [i for i in time_idx if i > pd.to_datetime('1989-01-01')]\n",
    "\n",
    "# Validation-window length (months) per horizon; 12 matches `validation_window` in the\n",
    "# configuration cell. Y2 uses 24, presumably because y2 targets are announced later, so a\n",
    "# 12-month window would hold few rows with ANNDATS_y2 < t. Kept per horizon instead of\n",
    "# reassigning the global `validation_window` inside the loop, so the Y2 setting cannot\n",
    "# leak into the quarterly / Y1 horizons of later months.\n",
    "validation_windows = {'q1': 12, 'q2': 12, 'q3': 12, 'y1': 12, 'y2': 24}\n",
    "# (horizon suffix used in column names, base feature set); order fixes the output column order\n",
    "horizons = [(f'q{q}', X_col_qtr) for q in [1, 2, 3]] + [(f'y{y}', X_col_ann) for y in [1, 2]]\n",
    "\n",
    "for (mdl_abbr, mdl_class, param_grid) in models:\n",
    "    pred_value = []\n",
    "    for t in tqdm(time_idx):\n",
    "        pred_value_t = []\n",
    "        for h, X_base in horizons:\n",
    "            X_col = X_base + [f'EPS_ana_{h}']\n",
    "            y_col = f'EPS_true_{h}'\n",
    "            ann_date = df_tmp[f'ANNDATS_{h}']\n",
    "            valid_start = t - MonthEnd(validation_windows[h])\n",
    "\n",
    "            ### sample splitting ###\n",
    "            # train: `train_window` months before the validation window, with the target\n",
    "            # announced (rolled to month end) before the validation window starts\n",
    "            train_data = df_tmp[(df_tmp['YearMonth'] >= valid_start - MonthEnd(train_window)) &\n",
    "                                (df_tmp['YearMonth'] < valid_start) &\n",
    "                                (ann_date + MonthEnd(0) < valid_start)\n",
    "                               ].set_index(['permno','YearMonth'])\n",
    "            # validation: months in [valid_start, t) whose target was announced before t\n",
    "            validation_data = df_tmp[(df_tmp['YearMonth'] >= valid_start) &\n",
    "                                     (df_tmp['YearMonth'] < t) &\n",
    "                                     (ann_date < t)\n",
    "                                    ].set_index(['permno','YearMonth'])\n",
    "            # test: firms observed at t whose target is announced after t\n",
    "            test_data = df_tmp[(ann_date > df_tmp['YearMonth']) & (df_tmp['YearMonth'] == t)\n",
    "                              ].set_index(['permno','YearMonth'])\n",
    "\n",
    "            (train_X, train_y, valid_X, valid_y,\n",
    "             train_valid_X, train_valid_y, test_X, test_y) = get_data(train_data, validation_data, test_data, X_col, y_col)\n",
    "            if test_X.empty:\n",
    "                # nothing to forecast for this horizon at t\n",
    "                continue\n",
    "\n",
    "            ## Validation to choose Best Parameter, then refit on train + validation\n",
    "            best_param, _ = GridSearch(mdl_class, param_grid,\n",
    "                                       train_X, train_y, valid_X, valid_y,\n",
    "                                       r2_score)\n",
    "            best_mdl = mdl_class().set_params(**best_param).fit(train_valid_X, train_valid_y)\n",
    "            # np.ravel flattens a 2-D (n, 1) prediction (as PLS may return) and leaves 1-D output unchanged\n",
    "            pred_value_t.append(pd.Series(np.ravel(best_mdl.predict(test_X)),\n",
    "                                          name=f'{mdl_abbr}_EPS_{h.upper()}', index=test_X.index))\n",
    "\n",
    "        if pred_value_t:\n",
    "            pred_value.append(pd.concat(pred_value_t, axis=1))\n",
    "\n",
    "    pred_value = pd.concat(pred_value, axis=0)\n",
    "    pred_value.reset_index().to_parquet(f'{output_dir}{mdl_abbr}_pred.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "## LASSO, ElasticNet and PLS are run on standardized data: every month, each ratio\n",
    "## characteristic is divided by its cross-sectional std and then multiplied by the\n",
    "## cross-sectional std of last EPS (EPS_true_l1_q1), i.e. rescaled to the EPS scale.\n",
    "df_tmp = pd.read_parquet('../data/Results/df_train_new.parquet')\n",
    "df_tmp['Year'] = df_tmp['YearMonth'].dt.year\n",
    "cols = ratio_chars\n",
    "by_month = df_tmp.groupby('YearMonth', group_keys=False)\n",
    "eps_scale = by_month['EPS_true_l1_q1'].transform('std')\n",
    "unit_std = by_month[cols].transform(lambda x: x / x.std())\n",
    "df_tmp[cols] = unit_std.mul(eps_scale, axis=0)\n",
    "\n",
    "## Linear models fitted on the standardized data\n",
    "models = [\n",
    "    ('PLS', PLSRegression, {'n_components': np.arange(1, 20)}),\n",
    "    ('LASSO', Lasso, {'alpha': np.logspace(-4, 0, 20)}),\n",
    "    ('ENet', ElasticNet, {'alpha': np.logspace(-4, 0, 20)}),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 372/372 [6:02:13<00:00, 58.42s/it]  \n",
      "100%|██████████| 372/372 [1:00:02<00:00,  9.68s/it]\n",
      "100%|██████████| 372/372 [57:41<00:00,  9.30s/it]\n"
     ]
    }
   ],
   "source": [
    "## Rolling-window out-of-sample forecasts for the models in `models` (PLS, LASSO, ENet on the\n",
    "## standardized data): for every month t and horizon, tune hyper-parameters on the validation\n",
    "## window, refit on train + validation with the best parameters, and predict the firms at t.\n",
    "time_idx = sorted(df_tmp['YearMonth'].unique())\n",
    "time_idx = [i for i in time_idx if i > pd.to_datetime('1989-01-01')]\n",
    "\n",
    "# Validation-window length (months) per horizon; 12 matches `validation_window` in the\n",
    "# configuration cell. Y2 uses 24, presumably because y2 targets are announced later, so a\n",
    "# 12-month window would hold few rows with ANNDATS_y2 < t. Stated explicitly here rather than\n",
    "# read from the global `validation_window`, so neither this loop nor a previously run cell\n",
    "# can carry the Y2 setting over to the quarterly / Y1 horizons.\n",
    "validation_windows = {'q1': 12, 'q2': 12, 'q3': 12, 'y1': 12, 'y2': 24}\n",
    "# (horizon suffix used in column names, base feature set); order fixes the output column order\n",
    "horizons = [(f'q{q}', X_col_qtr) for q in [1, 2, 3]] + [(f'y{y}', X_col_ann) for y in [1, 2]]\n",
    "\n",
    "for (mdl_abbr, mdl_class, param_grid) in models:\n",
    "    pred_value = []\n",
    "    for t in tqdm(time_idx):\n",
    "        pred_value_t = []\n",
    "        for h, X_base in horizons:\n",
    "            X_col = X_base + [f'EPS_ana_{h}']\n",
    "            y_col = f'EPS_true_{h}'\n",
    "            ann_date = df_tmp[f'ANNDATS_{h}']\n",
    "            valid_start = t - MonthEnd(validation_windows[h])\n",
    "\n",
    "            ### sample splitting ###\n",
    "            # train: `train_window` months before the validation window, with the target\n",
    "            # announced (rolled to month end) before the validation window starts\n",
    "            train_data = df_tmp[(df_tmp['YearMonth'] >= valid_start - MonthEnd(train_window)) &\n",
    "                                (df_tmp['YearMonth'] < valid_start) &\n",
    "                                (ann_date + MonthEnd(0) < valid_start)\n",
    "                               ].set_index(['permno','YearMonth'])\n",
    "            # validation: months in [valid_start, t) whose target was announced before t\n",
    "            validation_data = df_tmp[(df_tmp['YearMonth'] >= valid_start) &\n",
    "                                     (df_tmp['YearMonth'] < t) &\n",
    "                                     (ann_date < t)\n",
    "                                    ].set_index(['permno','YearMonth'])\n",
    "            # test: firms observed at t whose target is announced after t\n",
    "            test_data = df_tmp[(ann_date > df_tmp['YearMonth']) & (df_tmp['YearMonth'] == t)\n",
    "                              ].set_index(['permno','YearMonth'])\n",
    "\n",
    "            (train_X, train_y, valid_X, valid_y,\n",
    "             train_valid_X, train_valid_y, test_X, test_y) = get_data(train_data, validation_data, test_data, X_col, y_col)\n",
    "            if test_X.empty:\n",
    "                # nothing to forecast for this horizon at t\n",
    "                continue\n",
    "\n",
    "            ## Validation to choose Best Parameter, then refit on train + validation\n",
    "            best_param, _ = GridSearch(mdl_class, param_grid,\n",
    "                                       train_X, train_y, valid_X, valid_y,\n",
    "                                       r2_score)\n",
    "            best_mdl = mdl_class().set_params(**best_param).fit(train_valid_X, train_valid_y)\n",
    "            # np.ravel flattens a 2-D (n, 1) PLS prediction and leaves 1-D output unchanged\n",
    "            pred_value_t.append(pd.Series(np.ravel(best_mdl.predict(test_X)),\n",
    "                                          name=f'{mdl_abbr}_EPS_{h.upper()}', index=test_X.index))\n",
    "\n",
    "        if pred_value_t:\n",
    "            pred_value.append(pd.concat(pred_value_t, axis=1))\n",
    "\n",
    "    pred_value = pd.concat(pred_value, axis=0)\n",
    "    pred_value.reset_index().to_parquet(f'{output_dir}{mdl_abbr}_pred.parquet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table F.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Load each model's out-of-sample forecasts and merge them into one panel\n",
    "df_tmp = pd.read_parquet('../data/Results/df_train_new.parquet')\n",
    "f_abbr_list = [('OLS_pred','OLS'),\n",
    "               ('PLS_pred','PLS'),\n",
    "               ('LASSO_pred','LASSO'),\n",
    "               ('ENet_pred','ENet'),\n",
    "               ('RF_pred','RF'),\n",
    "               ('LGBM_pred','LGBM'),\n",
    "              ]\n",
    "abbr_list = [abbr for _, abbr in f_abbr_list]\n",
    "horizon_labels = ['Q1', 'Q2', 'Q3', 'Y1', 'Y2']\n",
    "\n",
    "per_model = []\n",
    "for fname, abbr in f_abbr_list:\n",
    "    pred = pd.read_parquet(f'../data/Results/ML_variants/{fname}.parquet')\n",
    "    keep = ['permno', 'YearMonth'] + [f'{abbr}_EPS_{h}' for h in horizon_labels]\n",
    "    per_model.append(pred[keep].set_index(['permno', 'YearMonth']))\n",
    "forecast_all = reduce(lambda left, right: pd.merge(left, right, on=['permno','YearMonth'], how='outer'),\n",
    "                      per_model).reset_index()\n",
    "\n",
    "## Composite: equal-weight mean of the model forecasts, skipping missing ones\n",
    "for h in horizon_labels:\n",
    "    forecast_all[f'Composite_EPS_{h}'] = forecast_all[[f'{abbr}_EPS_{h}' for abbr in abbr_list]].mean(axis=1)\n",
    "\n",
    "df = df_tmp.merge(forecast_all, on=['permno','YearMonth'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ## compare Hughes\n",
    "# # Hughes et al. (2008)\n",
    "# forecast_Hughes = pd.read_parquet('../data/Results/Hughes_eps.parquet')\n",
    "# mse = []\n",
    "# for h in ['q1','q2','q3','y1','y2']:\n",
    "#     mse.append(((forecast_Hughes[f'LF_{h}'] - forecast_Hughes[f'AE_{h}'])**2).groupby(forecast_Hughes['YearMonth']).mean())\n",
    "# mse_Hughes = pd.concat(mse, axis=1)\n",
    "# mse_Hughes = mse_Hughes[mse_Hughes.index>='1989-01-01']\n",
    "# mse_Hughes.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mse = []\n",
    "# for h in ['q1','q2','q3','y1','y2']:\n",
    "#     mse.append(((df[f'RF_EPS_{h.upper()}'] - df[f'EPS_true_{h}'])**2).groupby(df['YearMonth']).mean())\n",
    "# mse_woLAB = pd.concat(mse, axis=1)\n",
    "# mse_woLAB.mean()\n",
    "\n",
    "# mse_Hughes[1].plot(label='Hughes')\n",
    "# mse_woLAB[1].plot(label='RF')\n",
    "# plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mdl = []\n",
    "# for i in range(5):\n",
    "#     mdl.append(sm.OLS(endog=mse_Hughes[i]-mse_woLAB[i], \n",
    "#                       exog=[1]*len(mse_woLAB)).fit(cov_type = 'HAC', \n",
    "#                                                  cov_kwds = {'maxlags':12}))\n",
    "# summary2.summary_col(mdl, float_format='%.3f')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test = pd.read_parquet(f'../data/Results/RF_wo_lookahead_raw_005.parquet')\n",
    "# test = test[test['YearMonth'] >= '1989-01-31']\n",
    "# for h in ['q1','q2','q3','y1','y2']:\n",
    "#     print(((test[f'RF_{h}'] - test[f'AE_{h}'])**2).groupby(test['YearMonth']).mean().mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "## EX-POST Realization\n",
    "df['REAL_EPS_Q1'] = df['EPS_true_q1']\n",
    "df['REAL_EPS_Q2'] = df['EPS_true_q2']\n",
    "df['REAL_EPS_Q3'] = df['EPS_true_q3']\n",
    "df['REAL_EPS_Y1'] = df['EPS_true_y1']\n",
    "df['REAL_EPS_Y2'] = df['EPS_true_y2']\n",
    "## Analyst Forecast\n",
    "df['ANA_EPS_Q1'] = df['EPS_ana_q1']\n",
    "df['ANA_EPS_Q2'] = df['EPS_ana_q2']\n",
    "df['ANA_EPS_Q3'] = df['EPS_ana_q3']\n",
    "df['ANA_EPS_Y1'] = df['EPS_ana_y1']\n",
    "df['ANA_EPS_Y2'] = df['EPS_ana_y2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>OLS</th>\n",
       "      <th>PLS</th>\n",
       "      <th>LASSO</th>\n",
       "      <th>ENet</th>\n",
       "      <th>RF</th>\n",
       "      <th>LGBM</th>\n",
       "      <th>Composite</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Q1</th>\n",
       "      <td>0.055129</td>\n",
       "      <td>0.055279</td>\n",
       "      <td>0.055113</td>\n",
       "      <td>0.055224</td>\n",
       "      <td>0.053893</td>\n",
       "      <td>0.05299</td>\n",
       "      <td>0.053486</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Q2</th>\n",
       "      <td>0.076508</td>\n",
       "      <td>0.079382</td>\n",
       "      <td>0.076719</td>\n",
       "      <td>0.076876</td>\n",
       "      <td>0.074124</td>\n",
       "      <td>0.074175</td>\n",
       "      <td>0.074312</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Q3</th>\n",
       "      <td>0.106386</td>\n",
       "      <td>0.109238</td>\n",
       "      <td>0.103965</td>\n",
       "      <td>0.104667</td>\n",
       "      <td>0.100717</td>\n",
       "      <td>0.101822</td>\n",
       "      <td>0.101085</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Y1</th>\n",
       "      <td>0.546299</td>\n",
       "      <td>0.544101</td>\n",
       "      <td>0.53725</td>\n",
       "      <td>0.537542</td>\n",
       "      <td>0.529844</td>\n",
       "      <td>0.528135</td>\n",
       "      <td>0.522218</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Y2</th>\n",
       "      <td>1.854387</td>\n",
       "      <td>1.670472</td>\n",
       "      <td>1.636439</td>\n",
       "      <td>1.638248</td>\n",
       "      <td>1.589085</td>\n",
       "      <td>1.600711</td>\n",
       "      <td>1.571522</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         OLS       PLS     LASSO      ENet        RF      LGBM Composite\n",
       "Q1  0.055129  0.055279  0.055113  0.055224  0.053893   0.05299  0.053486\n",
       "Q2  0.076508  0.079382  0.076719  0.076876  0.074124  0.074175  0.074312\n",
       "Q3  0.106386  0.109238  0.103965  0.104667  0.100717  0.101822  0.101085\n",
       "Y1  0.546299  0.544101   0.53725  0.537542  0.529844  0.528135  0.522218\n",
       "Y2  1.854387  1.670472  1.636439  1.638248  1.589085  1.600711  1.571522"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "### 1. Forecast Performance\n",
    "## Time-series average of the monthly cross-sectional MSE, per horizon (rows) and model (columns).\n",
    "## Each model is evaluated on its own non-missing sample; N_obs records that sample size so\n",
    "## coverage differences across models can be checked.\n",
    "idx = ['Q1','Q2','Q3','Y1','Y2']\n",
    "col = abbr_list + ['Composite']\n",
    "MSE = pd.DataFrame(index=idx, columns=col, dtype=float)\n",
    "N_obs = pd.DataFrame(0, index=idx, columns=col)\n",
    "for c in col:\n",
    "    for i in idx:\n",
    "        df_ = df.dropna(subset=[f'REAL_EPS_{i}',f'{c}_EPS_{i}'])\n",
    "        sq_err = (df_[f'REAL_EPS_{i}'] - df_[f'{c}_EPS_{i}'])**2\n",
    "        MSE.loc[i, c] = sq_err.groupby(df_['YearMonth']).mean().mean()\n",
    "        N_obs.loc[i, c] = len(df_)\n",
    "\n",
    "try:\n",
    "    MSE.to_clipboard()\n",
    "except Exception as err:  # e.g. no clipboard available on a headless server; keep Run All going\n",
    "    warnings.warn(f'MSE table not copied to clipboard: {err}')\n",
    "MSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Panel A of Table F.3\n",
    "## MSE of each model compared to the RF model\n",
    "rlts = []\n",
    "for i in idx:\n",
    "    # Monthly cross-sectional MSE per model, each on its own non-missing sample. The month\n",
    "    # index comes only from this horizon's data (outer join across models), not from a\n",
    "    # frame left over from an earlier cell or horizon.\n",
    "    monthly_mse = {}\n",
    "    for c in col:\n",
    "        df_ = df.dropna(subset=[f'REAL_EPS_{i}',f'{c}_EPS_{i}'])\n",
    "        monthly_mse[c] = ((df_[f'REAL_EPS_{i}']-df_[f'{c}_EPS_{i}'])**2).groupby(df_['YearMonth']).mean()\n",
    "    MSE = pd.DataFrame(monthly_mse).sort_index()\n",
    "\n",
    "    # Accuracy improvement over RF: MSE(RF) - MSE(model); positive means the model beats RF\n",
    "    MSE_diff = -MSE.sub(MSE['RF'], axis=0)\n",
    "\n",
    "    # Test the mean difference with a constant-only OLS and HAC standard errors (12 lags);\n",
    "    # months where either the model or RF has no forecasts are dropped.\n",
    "    mdls = []\n",
    "    for m in [name for name in MSE_diff.columns if name != 'RF']:\n",
    "        diff = MSE_diff[m].dropna()\n",
    "        mdls.append(sm.OLS(endog=diff, exog=[1]*len(diff)).fit(cov_type='HAC', cov_kwds={'maxlags':12}))\n",
    "    rlt = summary2.summary_col(mdls, float_format='%0.3f' )\n",
    "\n",
    "    # First column: RF's own average monthly MSE\n",
    "    rlt.insert(0, 'RF', [round(MSE['RF'].mean(), 3),''])\n",
    "    rlt.index = [i,'']\n",
    "    rlts.append(rlt)\n",
    "rlts = pd.concat(rlts, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>RF</th>\n",
       "      <th>OLS</th>\n",
       "      <th>PLS</th>\n",
       "      <th>LASSO</th>\n",
       "      <th>ENet</th>\n",
       "      <th>LGBM</th>\n",
       "      <th>Composite</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Q1</th>\n",
       "      <td>0.054</td>\n",
       "      <td>-0.001</td>\n",
       "      <td>-0.001</td>\n",
       "      <td>-0.001</td>\n",
       "      <td>-0.001</td>\n",
       "      <td>0.001</td>\n",
       "      <td>0.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td></td>\n",
       "      <td>(-2.31)</td>\n",
       "      <td>(-2.30)</td>\n",
       "      <td>(-1.91)</td>\n",
       "      <td>(-2.11)</td>\n",
       "      <td>(3.18)</td>\n",
       "      <td>(1.47)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Q2</th>\n",
       "      <td>0.074</td>\n",
       "      <td>-0.002</td>\n",
       "      <td>-0.005</td>\n",
       "      <td>-0.003</td>\n",
       "      <td>-0.003</td>\n",
       "      <td>-0.000</td>\n",
       "      <td>-0.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td></td>\n",
       "      <td>(-2.40)</td>\n",
       "      <td>(-2.86)</td>\n",
       "      <td>(-2.18)</td>\n",
       "      <td>(-2.33)</td>\n",
       "      <td>(-0.13)</td>\n",
       "      <td>(-0.30)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Q3</th>\n",
       "      <td>0.101</td>\n",
       "      <td>-0.006</td>\n",
       "      <td>-0.009</td>\n",
       "      <td>-0.003</td>\n",
       "      <td>-0.004</td>\n",
       "      <td>-0.001</td>\n",
       "      <td>-0.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td></td>\n",
       "      <td>(-3.32)</td>\n",
       "      <td>(-3.15)</td>\n",
       "      <td>(-2.50)</td>\n",
       "      <td>(-3.05)</td>\n",
       "      <td>(-1.31)</td>\n",
       "      <td>(-0.43)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Y1</th>\n",
       "      <td>0.53</td>\n",
       "      <td>-0.016</td>\n",
       "      <td>-0.014</td>\n",
       "      <td>-0.007</td>\n",
       "      <td>-0.008</td>\n",
       "      <td>0.002</td>\n",
       "      <td>0.008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td></td>\n",
       "      <td>(-2.16)</td>\n",
       "      <td>(-1.84)</td>\n",
       "      <td>(-0.87)</td>\n",
       "      <td>(-0.91)</td>\n",
       "      <td>(0.71)</td>\n",
       "      <td>(1.50)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Y2</th>\n",
       "      <td>1.589</td>\n",
       "      <td>-0.265</td>\n",
       "      <td>-0.081</td>\n",
       "      <td>-0.047</td>\n",
       "      <td>-0.049</td>\n",
       "      <td>-0.012</td>\n",
       "      <td>0.018</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td></td>\n",
       "      <td>(-2.12)</td>\n",
       "      <td>(-3.19)</td>\n",
       "      <td>(-1.97)</td>\n",
       "      <td>(-1.95)</td>\n",
       "      <td>(-0.99)</td>\n",
       "      <td>(1.57)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       RF      OLS      PLS    LASSO     ENet     LGBM Composite\n",
       "Q1  0.054   -0.001   -0.001   -0.001   -0.001    0.001     0.000\n",
       "           (-2.31)  (-2.30)  (-1.91)  (-2.11)   (3.18)    (1.47)\n",
       "Q2  0.074   -0.002   -0.005   -0.003   -0.003   -0.000    -0.000\n",
       "           (-2.40)  (-2.86)  (-2.18)  (-2.33)  (-0.13)   (-0.30)\n",
       "Q3  0.101   -0.006   -0.009   -0.003   -0.004   -0.001    -0.000\n",
       "           (-3.32)  (-3.15)  (-2.50)  (-3.05)  (-1.31)   (-0.43)\n",
       "Y1   0.53   -0.016   -0.014   -0.007   -0.008    0.002     0.008\n",
       "           (-2.16)  (-1.84)  (-0.87)  (-0.91)   (0.71)    (1.50)\n",
       "Y2  1.589   -0.265   -0.081   -0.047   -0.049   -0.012     0.018\n",
       "           (-2.12)  (-3.19)  (-1.97)  (-1.95)  (-0.99)    (1.57)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Copy the results table to the system clipboard, then display it as the cell output.\n",
     "# pandas to_clipboard needs a clipboard backend (e.g. xclip/xsel on Linux) and may\n",
     "# fail on headless machines.\n",
     "# NOTE: `rlts` is rebound by the later Panel B cell; when re-running out of order,\n",
     "# re-run the preceding cell first so this shows the intended table.\n",
     "rlts.to_clipboard()\n",
     "rlts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Panel B of Table F.3\n",
    "all_factor = pd.read_csv('../data/Other/ff5_factors_m.CSV')\n",
    "all_factor['YearMonth'] = pd.to_datetime(all_factor['yyyymm'], format='%Y%m') + MonthEnd(0)\n",
    "all_factor['YearMonth'] = all_factor['YearMonth'] + MonthEnd(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Var:OLS_Bias_Avg, Delete 75966 rows due to missing values, raw data 1275551 rows --> new data 1199585 rows\n",
      "Var:PLS_Bias_Avg, Delete 75966 rows due to missing values, raw data 1275551 rows --> new data 1199585 rows\n",
      "Var:LASSO_Bias_Avg, Delete 75966 rows due to missing values, raw data 1275551 rows --> new data 1199585 rows\n",
      "Var:ENet_Bias_Avg, Delete 75966 rows due to missing values, raw data 1275551 rows --> new data 1199585 rows\n",
      "Var:RF_Bias_Avg, Delete 75966 rows due to missing values, raw data 1275551 rows --> new data 1199585 rows\n",
      "Var:LGBM_Bias_Avg, Delete 75966 rows due to missing values, raw data 1275551 rows --> new data 1199585 rows\n",
      "Var:Composite_Bias_Avg, Delete 75966 rows due to missing values, raw data 1275551 rows --> new data 1199585 rows\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>OLS</th>\n",
       "      <th>PLS</th>\n",
       "      <th>LASSO</th>\n",
       "      <th>ENet</th>\n",
       "      <th>RF</th>\n",
       "      <th>LGBM</th>\n",
       "      <th>Composite</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Ret</th>\n",
       "      <td>-0.16</td>\n",
       "      <td>-0.24</td>\n",
       "      <td>-0.12</td>\n",
       "      <td>-0.21</td>\n",
       "      <td>-0.34</td>\n",
       "      <td>-0.39</td>\n",
       "      <td>-0.15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>(-0.48)</td>\n",
       "      <td>(-0.66)</td>\n",
       "      <td>(-0.34)</td>\n",
       "      <td>(-0.59)</td>\n",
       "      <td>(-0.89)</td>\n",
       "      <td>(-1.21)</td>\n",
       "      <td>(-0.37)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CAPM</th>\n",
       "      <td>-0.59</td>\n",
       "      <td>-0.65</td>\n",
       "      <td>-0.56</td>\n",
       "      <td>-0.67</td>\n",
       "      <td>-0.84</td>\n",
       "      <td>-0.77</td>\n",
       "      <td>-0.60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>(-2.12)</td>\n",
       "      <td>(-2.09)</td>\n",
       "      <td>(-1.77)</td>\n",
       "      <td>(-2.16)</td>\n",
       "      <td>(-2.57)</td>\n",
       "      <td>(-2.62)</td>\n",
       "      <td>(-1.80)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>FF3</th>\n",
       "      <td>-0.66</td>\n",
       "      <td>-0.73</td>\n",
       "      <td>-0.63</td>\n",
       "      <td>-0.74</td>\n",
       "      <td>-0.95</td>\n",
       "      <td>-0.88</td>\n",
       "      <td>-0.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>(-3.24)</td>\n",
       "      <td>(-3.07)</td>\n",
       "      <td>(-2.77)</td>\n",
       "      <td>(-3.44)</td>\n",
       "      <td>(-4.54)</td>\n",
       "      <td>(-4.83)</td>\n",
       "      <td>(-3.01)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>FF5</th>\n",
       "      <td>-0.18</td>\n",
       "      <td>-0.23</td>\n",
       "      <td>-0.19</td>\n",
       "      <td>-0.30</td>\n",
       "      <td>-0.51</td>\n",
       "      <td>-0.55</td>\n",
       "      <td>-0.23</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>(-0.64)</td>\n",
       "      <td>(-0.77)</td>\n",
       "      <td>(-0.63)</td>\n",
       "      <td>(-1.04)</td>\n",
       "      <td>(-1.80)</td>\n",
       "      <td>(-2.03)</td>\n",
       "      <td>(-0.71)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>FFC6</th>\n",
       "      <td>0.33</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.33</td>\n",
       "      <td>0.21</td>\n",
       "      <td>-0.02</td>\n",
       "      <td>-0.04</td>\n",
       "      <td>0.32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>(1.54)</td>\n",
       "      <td>(1.10)</td>\n",
       "      <td>(1.40)</td>\n",
       "      <td>(0.85)</td>\n",
       "      <td>(-0.09)</td>\n",
       "      <td>(-0.19)</td>\n",
       "      <td>(1.33)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HXZ</th>\n",
       "      <td>0.19</td>\n",
       "      <td>0.09</td>\n",
       "      <td>0.18</td>\n",
       "      <td>0.06</td>\n",
       "      <td>-0.23</td>\n",
       "      <td>-0.21</td>\n",
       "      <td>0.14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>(0.60)</td>\n",
       "      <td>(0.29)</td>\n",
       "      <td>(0.55)</td>\n",
       "      <td>(0.18)</td>\n",
       "      <td>(-0.76)</td>\n",
       "      <td>(-0.62)</td>\n",
       "      <td>(0.40)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HMXZ</th>\n",
       "      <td>0.39</td>\n",
       "      <td>0.36</td>\n",
       "      <td>0.37</td>\n",
       "      <td>0.24</td>\n",
       "      <td>-0.01</td>\n",
       "      <td>-0.02</td>\n",
       "      <td>0.37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>(1.18)</td>\n",
       "      <td>(1.03)</td>\n",
       "      <td>(1.12)</td>\n",
       "      <td>(0.69)</td>\n",
       "      <td>(-0.03)</td>\n",
       "      <td>(-0.07)</td>\n",
       "      <td>(1.07)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SY</th>\n",
       "      <td>0.48</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.45</td>\n",
       "      <td>0.33</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.19</td>\n",
       "      <td>0.42</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>(1.94)</td>\n",
       "      <td>(1.65)</td>\n",
       "      <td>(1.62)</td>\n",
       "      <td>(1.26)</td>\n",
       "      <td>(0.29)</td>\n",
       "      <td>(0.77)</td>\n",
       "      <td>(1.63)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>DHS</th>\n",
       "      <td>0.50</td>\n",
       "      <td>0.47</td>\n",
       "      <td>0.52</td>\n",
       "      <td>0.41</td>\n",
       "      <td>0.22</td>\n",
       "      <td>0.22</td>\n",
       "      <td>0.52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>(1.29)</td>\n",
       "      <td>(1.09)</td>\n",
       "      <td>(1.28)</td>\n",
       "      <td>(0.98)</td>\n",
       "      <td>(0.53)</td>\n",
       "      <td>(0.61)</td>\n",
       "      <td>(1.19)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          OLS      PLS    LASSO     ENet       RF     LGBM Composite\n",
       "Ret     -0.16    -0.24    -0.12    -0.21    -0.34    -0.39     -0.15\n",
       "      (-0.48)  (-0.66)  (-0.34)  (-0.59)  (-0.89)  (-1.21)   (-0.37)\n",
       "CAPM    -0.59    -0.65    -0.56    -0.67    -0.84    -0.77     -0.60\n",
       "      (-2.12)  (-2.09)  (-1.77)  (-2.16)  (-2.57)  (-2.62)   (-1.80)\n",
       "FF3     -0.66    -0.73    -0.63    -0.74    -0.95    -0.88     -0.69\n",
       "      (-3.24)  (-3.07)  (-2.77)  (-3.44)  (-4.54)  (-4.83)   (-3.01)\n",
       "FF5     -0.18    -0.23    -0.19    -0.30    -0.51    -0.55     -0.23\n",
       "      (-0.64)  (-0.77)  (-0.63)  (-1.04)  (-1.80)  (-2.03)   (-0.71)\n",
       "FFC6     0.33     0.28     0.33     0.21    -0.02    -0.04      0.32\n",
       "       (1.54)   (1.10)   (1.40)   (0.85)  (-0.09)  (-0.19)    (1.33)\n",
       "HXZ      0.19     0.09     0.18     0.06    -0.23    -0.21      0.14\n",
       "       (0.60)   (0.29)   (0.55)   (0.18)  (-0.76)  (-0.62)    (0.40)\n",
       "HMXZ     0.39     0.36     0.37     0.24    -0.01    -0.02      0.37\n",
       "       (1.18)   (1.03)   (1.12)   (0.69)  (-0.03)  (-0.07)    (1.07)\n",
       "SY       0.48     0.44     0.45     0.33     0.06     0.19      0.42\n",
       "       (1.94)   (1.65)   (1.62)   (1.26)   (0.29)   (0.77)    (1.63)\n",
       "DHS      0.50     0.47     0.52     0.41     0.22     0.22      0.52\n",
       "       (1.29)   (1.09)   (1.28)   (0.98)   (0.53)   (0.61)    (1.19)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = ['Q1','Q2','Q3','Y1','Y2']\n",
    "col = abbr_list + ['Composite'] \n",
    "num_level = 5\n",
    "factor_dict = {'Ret': ['ones'],\n",
    "               'CAPM':['ones','Mkt_RF'],\n",
    "               'FF3': ['ones','Mkt_RF','SMB','HML'],\n",
    "               'FF5': ['ones','Mkt_RF','SMB', 'HML', 'RMW', 'CMA'],\n",
    "               'FFC6':['ones','Mkt_RF','SMB', 'HML', 'RMW', 'CMA','MOM'],\n",
    "               'HXZ':['ones','R_MKT','R_ME','R_IA','R_ROE'],\n",
    "               'HMXZ':['ones','R_MKT','R_ME','R_IA','R_ROE','R_EG'],\n",
    "               'SY':['ones','Mkt_RF','SMB_SY','MGMT', 'PERF'],\n",
    "               'DHS':['ones','Mkt_RF','PEAD', 'FIN'],\n",
    "               }\n",
    "\n",
    "rlts = []\n",
    "for c in col:\n",
    "    for i in idx:\n",
    "        df[f'{c}_Bias_{i}'] = (df[f'ANA_EPS_{i}'] - df[f'{c}_EPS_{i}'])/df['prc_l1']\n",
    "        \n",
    "    # Average Bias\n",
    "    df[f'{c}_Bias_Avg'] = df[[f'{c}_Bias_Q1',f'{c}_Bias_Q2',f'{c}_Bias_Q3',\n",
    "                              f'{c}_Bias_Y1',f'{c}_Bias_Y2']].mean(axis=1)\n",
    "    \n",
    "    nonNA = (~df[[f'{c}_Bias_Q1',f'{c}_Bias_Q2',f'{c}_Bias_Q3',\n",
    "                  f'{c}_Bias_Y1',f'{c}_Bias_Y2']].isna()).sum(axis=1)\n",
    "    df[f'{c}_Bias_Avg'] = np.where(nonNA > 1,\n",
    "                                   df[f'{c}_Bias_Avg'],\n",
    "                                   np.nan)\n",
    "\n",
    "    sort_var = f'{c}_Bias_Avg'\n",
    "    _,vwret1 = utils.SingleSort(df,'PERMNO', 'YearMonth', \n",
    "                                    sort_var, 'bh1m', num_level, \n",
    "                                    'ME', quantile_filter=None)\n",
    "    result = utils.SingleSort_RetAna(_,vwret1,'YearMonth',factor_data=all_factor,factor_dict=factor_dict,lag=12)\n",
    "    result = result['H-L']\n",
    "\n",
    "    result.name = c\n",
    "    rlts.append(result)\n",
    "    # break\n",
    "rlts = pd.concat(rlts,axis=1)\n",
    "rlts.to_clipboard()\n",
    "rlts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
